#include <flashinfer/attention/prefill.cuh>
#include "batch_prefill_config.inc"

namespace flashinfer {

// True when this file is rendered with the custom mask mode, false otherwise.
// NOTE(review): this constant is not referenced in the visible template text; presumably it
// is consumed by the rendered variant_name expression -- confirm before removing it.
constexpr auto use_custom_mask = {{ mask_mode }} == MaskMode::kCustom;

// Explicit instantiations of BatchPrefillWithPagedKVCacheDispatched, one per CTA_TILE_Q
// value in the loop list below (16, 64, 128). Every other template argument, as well as the
// element type of tmp_v, is substituted by the template engine at render time.
// NOTE(review): PagedParams is presumably declared by batch_prefill_config.inc -- confirm.
{% for cta_tile_q in [16, 64, 128] %}
template cudaError_t BatchPrefillWithPagedKVCacheDispatched<
    /*CTA_TILE_Q=*/{{cta_tile_q}}, {{head_dim_qk}}, {{head_dim_vo}}, {{pos_encoding_mode}}, {{use_fp16_qk_reduction}}, {{mask_mode}},
    {{ variant_name }}, PagedParams>(PagedParams params, {{ dtype_o }}* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream);
{% endfor %}

// A namespace definition is not terminated by ';' -- the trailing semicolon would be a
// stray empty declaration that triggers -Wextra-semi / -Wpedantic diagnostics.
}  // namespace flashinfer
